import os
import time
import math
import torch
import threading
import multiprocessing
import copy
import deepspeed
from PIL import Image
import torchvision
from image_synthesis.utils.misc import instantiate_from_config, format_seconds
from image_synthesis.distributed.distributed import reduce_dict
from image_synthesis.distributed.distributed import is_primary, get_rank
from image_synthesis.utils.misc import get_model_parameters_info
from image_synthesis.engine.lr_scheduler import ReduceLROnPlateauWithWarmup, CosineAnnealingLRWithWarmup
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
from torch.optim.optimizer import Optimizer
from deepspeed.runtime.engine import DeepSpeedEngine

# Scheduler classes whose step() is called with the current loss value
# (checked with isinstance in Solver.step); all other schedulers are stepped
# without arguments.
STEP_WITH_LOSS_SCHEDULERS = (ReduceLROnPlateauWithWarmup, ReduceLROnPlateau)

class Solver(object):
    def __init__(self, config, args, model, dataloader, logger):
        self.config = config
        self.args = args
        self.model = model 
        self.dataloader = dataloader
        self.logger = logger

        # initialize deepspeed engine earier
        self.optimizer_and_scheduler = self._get_optimizer_and_scheduler(config['solver']['optimizers_and_schedulers'])

        self.max_epochs = config['solver']['max_epochs']
        self.save_epochs = config['solver']['save_epochs']
        self.save_iterations = config['solver'].get('save_iterations', -1)
        self.sample_iterations = config['solver']['sample_iterations']
        if self.sample_iterations == 'epoch':
            self.sample_iterations = self.dataloader['train_iterations']
        self.validation_epochs = config['solver'].get('validation_epochs', 2)
        assert isinstance(self.save_epochs, (int, list))
        assert isinstance(self.validation_epochs, (int, list))

        self.last_epoch = -1
        self.ckpt_dir = os.path.join(self.args.save_dir, 'checkpoint')
        self.image_dir = os.path.join(self.args.save_dir, 'images')
        os.makedirs(self.ckpt_dir, exist_ok=True)
        os.makedirs(self.image_dir, exist_ok=True)

        # get lr
        adjust_lr = config['solver'].get('adjust_lr', 'sqrt')
        base_lr = config['solver'].get('base_lr', 1.0e-4)
        if adjust_lr == 'none':
            self.lr = base_lr
        elif adjust_lr == 'sqrt':
            self.lr = base_lr * math.sqrt(self.args.world_size * config['dataloader']['batch_size'])
        elif adjust_lr == 'linear':
            self.lr = base_lr * self.args.world_size * config['dataloader']['batch_size']
        else:
            raise NotImplementedError('Unknown type of adjust lr {}!'.format(adjust_lr))
        self.logger.log_info('Get lr {} from base lr {} with {}'.format(self.lr, base_lr, adjust_lr))


        self.logger.log_info(str(get_model_parameters_info(self.model)))
        self.device = self.model.device
        self.logger.log_info("{}: global rank {}: prepare solver done!".format(self.args.name,self.args.global_rank), check_primary=False)

    def _get_optimizer_and_scheduler(self, op_sc_list):
        """Create the DeepSpeed engine and the solver-managed LR scheduler.

        Side effects: replaces ``self.model`` with the DeepSpeed engine, writes
        ``local_rank``, ``global_rank`` and ``fp16`` onto ``self.args``, fills
        in ``self.config['deepspeed']`` and divides
        ``self.dataloader['train_iterations']`` by the engine's gradient
        accumulation steps.

        Args:
            op_sc_list: list with exactly one optimizer/scheduler config dict
                (optional keys ``'name'``, ``'optimizer'``, ``'scheduler'``).

        Returns:
            dict mapping the optimizer name to
            ``{'name', 'optimizer': {'module'}, ['scheduler': {'module', 'step_iteration'}]}``.

        Raises:
            RuntimeError: if a scheduler is configured in the deepspeed section.
        """
        assert len(op_sc_list) == 1, 'For DeepSpeed training, only one optimizer is supported!'
        optimizer_and_scheduler = {}
        op_sc_cfg = op_sc_list[0]

        op_sc = {
            'name': op_sc_cfg.get('name', 'none'),
        }

        # make parameters contiguous
        for _, p in self.model.named_parameters():
            p.data = p.data.contiguous()

        if op_sc['name'] == 'none':
            parameters = self.model.parameters()
        else:
            # NOTE: get the parameters with the given name; the model's parameters() must be overridden to accept it
            parameters = self.model.parameters(name=op_sc['name'])

        # merge optimizer configs from solver to deepspeed
        if 'optimizer' in op_sc_cfg:
            assert 'optimizer' not in self.config['deepspeed'], 'duplicated optimizer config!'
            op_cfg = op_sc_cfg['optimizer']
            # deepspeed expects the bare class name under 'type'
            op_cfg['type'] = op_cfg['target'].split('.')[-1]
            del op_cfg['target']
            self.config['deepspeed']['optimizer'] = op_cfg

        if 'scheduler' in self.config['deepspeed']:
            raise RuntimeError('For optimizer scheduler, we leave it for solver, rather than deepspeed engine, since it is more flexible for solver!')

        self.config['deepspeed']['train_micro_batch_size_per_gpu'] = self.config['dataloader']['batch_size']
        self.config['deepspeed']['steps_per_print'] = 1e20 # a big number, so that we only log some info in this solver, and there is no need to log some info by the deepspeed engine

        self.model, optimizer, _, _ = deepspeed.initialize(
            args=self.args,
            model=self.model,
            model_parameters=parameters,
            config=self.config['deepspeed'],
        )

        self.args.local_rank = self.model.local_rank
        self.args.global_rank = get_rank()
        self.args.fp16 = self.model.fp16_enabled()

        # unwrap to the underlying torch optimizer when deepspeed returns a wrapper
        basic_optimizer = optimizer if isinstance(optimizer, Optimizer) else optimizer.optimizer

        op_sc['optimizer'] = {
            'module':  basic_optimizer
        }

        # build scheduler, leave scheduler to solver, rather than deepspeed engine
        if 'scheduler' in op_sc_cfg:
            sc_cfg = op_sc_cfg['scheduler']
            sc_cfg['params']['optimizer'] = basic_optimizer
            # for cosine annealing lr, compute T_max
            if sc_cfg['target'].split('.')[-1] in ['CosineAnnealingLRWithWarmup', 'CosineAnnealingLR']:
                # Read max_epochs from the config: this method runs during
                # __init__ and must not depend on attribute assignment order.
                # NOTE(review): T_max (and the 'epoch' step interval below) use
                # train_iterations *before* it is divided by the gradient
                # accumulation steps at the end of this method - confirm intended.
                T_max = self.config['solver']['max_epochs'] * self.dataloader['train_iterations']
                sc_cfg['params']['T_max'] = T_max
            scheduler = instantiate_from_config(sc_cfg)
            op_sc['scheduler'] = {
                'module': scheduler,
                'step_iteration': sc_cfg.get('step_iteration', 1)
            }
            if op_sc['scheduler']['step_iteration'] == 'epoch':
                op_sc['scheduler']['step_iteration'] = self.dataloader['train_iterations']

        optimizer_and_scheduler[op_sc['name']] = op_sc

        # from here on, train_iterations is expressed in optimizer steps
        self.dataloader['train_iterations'] = self.dataloader['train_iterations'] // self.model.gradient_accumulation_steps()

        return optimizer_and_scheduler

    def _get_lr(self, return_type='str'):
        """Report the learning rate of the first param group of each optimizer.

        Args:
            return_type: ``'dict'`` for ``{'<name>_lr': value}`` (values rounded
                to 10 decimals) or ``'str'`` for the same mapping rendered as
                text with braces and quotes stripped and ``'none'`` replaced
                by ``'lr'``.

        Raises:
            ValueError: for any other ``return_type``.
        """
        rates = {
            name + '_lr': round(entry['optimizer']['module'].state_dict()['param_groups'][0]['lr'], 10)
            for name, entry in self.optimizer_and_scheduler.items()
        }
        if return_type == 'dict':
            return rates
        if return_type == 'str':
            text = str(rates).replace('none', 'lr')
            for ch in ('{', '}', '\''):
                text = text.replace(ch, '')
            return text
        raise ValueError('Unknow of return type: {}'.format(return_type))

    def sample(self, batch, phase='train', step_type='iteration'):
        """Run ``model.sample`` on a batch and write the results to disk and the logger.

        Each entry returned by ``model.sample(batch=batch)`` is stored under
        ``<image_dir>/<phase>/<key>/``. 4-D tensors with 1 or 3 channels are
        treated as image batches: cast to uint8, sent to the logger via
        ``add_images`` and saved as a ``.jpg`` grid. Every other value is
        appended as text to a ``.txt`` file.

        Args:
            batch: input batch passed through to ``model.sample``.
            phase: sub-directory / tag prefix, e.g. ``'train'`` or ``'val'``.
            step_type: ``'iteration'`` logs against the engine's global step;
                anything else logs against ``self.last_epoch``.
        """
        tic = time.time()
        self.logger.log_info('Begin to sample...')

        # sample with the bare module when wrapped by the DeepSpeed engine
        model = self.model.module if isinstance(self.model, DeepSpeedEngine) else self.model

        with torch.no_grad():
            samples = model.sample(batch=batch)
            step = self.model.global_steps if step_type == 'iteration' else self.last_epoch
            for k, v in samples.items():
                save_dir = os.path.join(self.image_dir, phase, k)
                os.makedirs(save_dir, exist_ok=True)
                itr_in_epoch = self.model.global_steps % self.dataloader['train_iterations']
                save_path = os.path.join(save_dir, 'e{:010d}_itr{:010d}_rank{}'.format(self.last_epoch, itr_in_epoch, get_rank()))
                if torch.is_tensor(v) and v.dim() == 4 and v.shape[1] in [1, 3]: # image
                    # NOTE(review): assumes values are already in [0, 255] - confirm
                    im = v.to(torch.uint8)
                    self.logger.add_images(tag='{}/{}e_{}itr/{}'.format(phase, self.last_epoch, itr_in_epoch, k), img_tensor=im, global_step=step, dataformats='NCHW')

                    # save images
                    im_grid = torchvision.utils.make_grid(im)
                    im_grid = im_grid.permute(1, 2, 0).to('cpu').numpy()
                    im_grid = Image.fromarray(im_grid)

                    im_grid.save(save_path + '.jpg')
                    self.logger.log_info('save {} to {}'.format(k, save_path+'.jpg'))
                else: # may be other values, such as text caption
                    # the context manager closes the file; no explicit close needed
                    with open(save_path+'.txt', 'a') as f:
                        f.write(str(v)+'\n')
                    # log the real file name (the '.' was missing before)
                    self.logger.log_info('save {} to {}'.format(k, save_path+'.txt'))

        self.logger.log_info('Sample done, time: {:.2f}'.format(time.time() - tic))

    def step(self, batch, phase='train'):
        """Run one forward pass (and, when training, backward + optimizer step).

        Tensor values of ``batch`` are moved to CUDA in place; when the engine
        runs in fp16, float32 tensors are additionally cast to half.

        Args:
            batch: dict of inputs; modified in place.
            phase: ``'train'`` runs backward, the engine step and the scheduler
                check; any other value runs the forward pass under
                ``torch.no_grad()`` only.

        Returns:
            ``{optimizer_name: {key: value}}`` holding every model output whose
            key contains ``'loss'`` or ``'acc'``.
        """
        loss = {}
        use_fp16 = self.model.fp16_enabled()
        for k, v in batch.items():
            if torch.is_tensor(v):
                v = v.cuda()
                # The former isinstance(v, torch.FloatTensor) test is never true
                # for a CUDA tensor, so the cast was silently skipped; compare
                # the dtype instead.
                if use_fp16 and v.dtype == torch.float32:
                    v = v.half()
            batch[k] = v

        keys = list(self.optimizer_and_scheduler.keys())
        assert len(keys) == 1
        op_sc_n = keys[0]
        op_sc = self.optimizer_and_scheduler[op_sc_n]

        inputs = {
            'batch': batch,
            'return_loss': True,
            'step': self.model.global_steps,
            }
        if op_sc_n != 'none':
            inputs['name'] = op_sc_n

        if phase == 'train':
            output = self.model(**inputs)
            self.model.backward(output['loss'])
            self.model.step()

            # check scheduler
            if 'scheduler' in op_sc:
                if op_sc['scheduler']['step_iteration'] > 0 and (self.model.global_steps + 1) % op_sc['scheduler']['step_iteration'] == 0 and self.model.is_gradient_accumulation_boundary():
                    if isinstance(op_sc['scheduler']['module'], STEP_WITH_LOSS_SCHEDULERS):
                        # plateau-style schedulers need the metric they monitor
                        op_sc['scheduler']['module'].step(output.get('loss'))
                    else:
                        op_sc['scheduler']['module'].step()
        else:
            with torch.no_grad():
                output = self.model(**inputs)

        loss[op_sc_n] = {k: v for k, v in output.items() if ('loss' in k or 'acc' in k)}
        return loss

    def save(self, force=False):
        """Write a DeepSpeed checkpoint when one is due, or when forced.

        A checkpoint is due when ``save_iterations > 0`` and the next global
        step is a multiple of it at a gradient-accumulation boundary; otherwise
        when ``last_epoch + 1`` matches ``save_epochs`` (int period or explicit
        list). A due checkpoint is written under an epoch/iteration specific
        id; in every saving case the ``'last'`` checkpoint is written too.

        The client state holds ``last_epoch`` and, per optimizer entry, its
        plain metadata plus the scheduler ``state_dict``. The live optimizer
        object is not stored: the engine's ``save_checkpoint`` is responsible
        for optimizer state, and pickling the object would duplicate it.

        Args:
            force: write the ``'last'`` checkpoint even if none is due.
        """
        if self.save_iterations > 0:
            save = (self.model.global_steps + 1) % self.save_iterations == 0 and self.model.is_gradient_accumulation_boundary()
        elif isinstance(self.save_epochs, int):
            save = (self.last_epoch + 1) % self.save_epochs == 0
        else:
            assert isinstance(self.save_epochs, list), 'the save epochs can only be int or list of int!'
            save = (self.last_epoch + 1) in self.save_epochs

        if not (save or force):
            return

        # collect scheduler state, replacing live objects by their state_dict
        scheduler = {}
        for op_sc_n, op_sc in self.optimizer_and_scheduler.items():
            state_ = {}
            for k, v in op_sc.items():
                if k == 'optimizer':
                    # live optimizer object; see docstring
                    continue
                if k == 'scheduler':
                    sc_state = {kk: vv for kk, vv in v.items() if kk != 'module'}
                    sc_state['module'] = v['module'].state_dict()
                    state_[k] = sc_state
                else:
                    state_[k] = v
            scheduler[op_sc_n] = state_

        state_dict = {
            'last_epoch': self.last_epoch,
            'scheduler': scheduler,
        }

        # (checkpoint id, log message template) in the order they are written
        targets = []
        if save:
            # save with the epoch specified name
            targets.append((
                '{}e_{}iter'.format(str(self.last_epoch).zfill(6), self.model.global_steps),
                'saved to PATH={}, ckpt_id={}',
            ))
        # save with the last name
        targets.append(('last', 'saved to ditrectory={}, ckpt_id={}'))

        for ckpt_id, template in targets:
            success = self.model.save_checkpoint(self.ckpt_dir, ckpt_id, state_dict)
            status_msg = template.format(self.ckpt_dir, ckpt_id)
            if success:
                self.logger.log_info('Successed: {}'.format(status_msg))
            else:
                self.logger.log_info('Failed: {}'.format(status_msg))
    
    def resume(self, 
               ckpt_id=None, # checkpoint id (tag) inside self.ckpt_dir; defaults to 'last'
               load_optimizer_and_scheduler=True, # whether to load optimizers and scheduler
               load_others=True # load other informations
               ): 
        """Restore engine, scheduler and epoch counter from a checkpoint.

        Args:
            ckpt_id: checkpoint id inside ``self.ckpt_dir``; ``None`` means
                ``'last'``.
            load_optimizer_and_scheduler: forwarded to the engine's
                ``load_checkpoint`` and also gates restoring the
                solver-managed scheduler state.
            load_others: restore ``self.last_epoch`` from the checkpoint.

        Raises:
            RuntimeError: if the engine returns no client state for the
                requested checkpoint.
        """
        if ckpt_id is None:
            ckpt_id = 'last'

        _, state_dict = self.model.load_checkpoint(
            self.ckpt_dir, 
            ckpt_id,
            load_optimizer_states=load_optimizer_and_scheduler,
            load_lr_scheduler_states=load_optimizer_and_scheduler)

        if state_dict is None:
            raise RuntimeError('Failed to load checkpoint from ditrectory={}, ckpt_id={}'.format(self.ckpt_dir, ckpt_id))

        # load scheduler
        if load_optimizer_and_scheduler:
            for op_sc_n, saved in state_dict.get('scheduler', {}).items():
                # save() nests the scheduler state under the 'scheduler' key of
                # each optimizer entry; the previous loop looked for 'module'
                # one level too high and therefore never restored anything.
                if 'scheduler' in saved:
                    sc_state = saved['scheduler']
                elif 'module' in saved:
                    # tolerate a flat {'module': ..., 'step_iteration': ...} entry
                    sc_state = saved
                else:
                    continue
                target = self.optimizer_and_scheduler.get(op_sc_n, {}).get('scheduler')
                if target is None:
                    # no scheduler configured for this run; nothing to restore into
                    continue
                for k, v in sc_state.items():
                    if k == 'module':
                        target['module'].load_state_dict(v)
                    else:
                        target[k] = v

        if load_others:
            self.last_epoch = state_dict['last_epoch']
        
        self.logger.log_info('Resume from ditrectory={}, ckpt_id={}'.format(self.ckpt_dir, ckpt_id))
    
    def train_epoch(self):
        self.model.train()
        self.last_epoch += 1
        self.logger.log_info("global rank: {}, started epoch {}".format(self.args.global_rank, self.last_epoch), check_primary=False)

        if self.args.distributed:
            self.dataloader['train_loader'].sampler.set_epoch(self.last_epoch)

        epoch_start = time.time()
        itr_start = time.time()
        itr = -1
        for itr, batch in enumerate(self.dataloader['train_loader']):
            data_time = time.time() - itr_start
            step_start = time.time()

            loss = self.step(batch, phase='train')

            # logging info
            if self.logger is not None and self.model.global_steps % self.args.log_frequency == 0 and self.model.is_gradient_accumulation_boundary():
                info = '{}: train'.format(self.args.name)
                info = info + ': Epoch {}/{} iter {}/{}'.format(self.last_epoch, self.max_epochs, self.model.global_steps%self.dataloader['train_iterations'], self.dataloader['train_iterations'])
                for loss_n, loss_dict in loss.items():
                    info += ' ||'
                    loss_dict = reduce_dict(loss_dict)
                    info += '' if loss_n == 'none' else ' {}'.format(loss_n)
                    # info = info + ': Epoch {}/{} iter {}/{}'.format(self.last_epoch, self.max_epochs, self.model.global_steps%self.dataloader['train_iterations'], self.dataloader['train_iterations'])
                    for k in loss_dict:
                        info += ' | {}: {:.4f}'.format(k, float(loss_dict[k]))
                        self.logger.add_scalar(tag='train/{}/{}'.format(loss_n, k), scalar_value=float(loss_dict[k]), global_step=self.model.global_steps)
                
                # log lr
                lrs = self._get_lr(return_type='dict')
                for k in lrs.keys():
                    lr = lrs[k]
                    self.logger.add_scalar(tag='train/{}_lr'.format(k), scalar_value=lrs[k], global_step=self.model.global_steps)

                # add lr to info
                info += ' || {}'.format(self._get_lr())
                
                info += ' || skipped_iters: {}'.format(self.model.skipped_steps)

                # add time consumption to info
                spend_time = time.time() - self.start_train_time
                itr_time_avg = spend_time / (self.model.global_steps + 1)
                info += ' || data_time: {dt}s | fbward_time: {fbt}s | iter_time: {it}s | iter_avg_time: {ita}s | epoch_time: {et} | spend_time: {st} | left_time: {lt}'.format(
                        dt=round(data_time, 1),
                        it=round(time.time() - itr_start, 1),
                        fbt=round(time.time() - step_start, 1),
                        ita=round(itr_time_avg, 1),
                        et=format_seconds(time.time() - epoch_start),
                        st=format_seconds(spend_time),
                        lt=format_seconds(itr_time_avg*self.max_epochs*self.dataloader['train_iterations']-spend_time)
                        )
                self.logger.log_info(info)
            
            itr_start = time.time()

            # sample
            if self.sample_iterations > 0 and (self.model.global_steps + 1) % self.sample_iterations == 0 and self.model.is_gradient_accumulation_boundary():
                self.model.eval()
                self.sample(batch, phase='train', step_type='iteration')
                if 'validation_loader' in self.dataloader:
                    for _, batch_val in enumerate(self.dataloader['validation_loader']):
                        self.sample(batch_val, phase='val', step_type='iteration')
                        break
                self.model.train()

            # save model
            if self.save_iterations > 0:
                self.save(force=False)

        # modify here to make sure dataloader['train_iterations'] is correct
        assert itr >= 0, "The data is too less to form one iteration!"
        self.dataloader['train_iterations'] = (itr + 1) // self.model.gradient_accumulation_steps()

    def validate_epoch(self):
        """Run validation over ``dataloader['validation_loader']`` when it is due.

        Validation is due when a validation loader exists and
        ``last_epoch + 1`` matches ``validation_epochs`` (int period or
        explicit list). Per-batch losses are reduced with ``reduce_dict`` and
        folded into a running mean that is logged at the end. Afterwards
        ``dataloader['validation_iterations']`` is refreshed from the number of
        batches seen.
        """
        if 'validation_loader' not in self.dataloader:
            return
        if isinstance(self.validation_epochs, int):
            val = (self.last_epoch + 1) % self.validation_epochs == 0
        else:
            val = (self.last_epoch + 1) in self.validation_epochs
        if not val:
            return

        if self.args.distributed:
            self.dataloader['validation_loader'].sampler.set_epoch(self.last_epoch)

        self.model.eval()
        overall_loss = None
        epoch_start = time.time()
        itr_start = time.time()
        itr = -1
        for itr, batch in enumerate(self.dataloader['validation_loader']):
            data_time = time.time() - itr_start
            step_start = time.time()
            loss = self.step(batch, phase='val')

            for loss_n, loss_dict in loss.items():
                loss[loss_n] = reduce_dict(loss_dict)
            if overall_loss is None:
                overall_loss = loss
            else:
                # running mean over the batches seen so far
                for loss_n, loss_dict in loss.items():
                    for k, v in loss_dict.items():
                        overall_loss[loss_n][k] = (overall_loss[loss_n][k] * itr + v) / (itr + 1)

            if self.logger is not None and (itr+1) % self.args.log_frequency == 0:
                info = '{}: val'.format(self.args.name)
                info = info + ': Epoch {}/{} | iter {}/{}'.format(self.last_epoch, self.max_epochs, itr, self.dataloader['validation_iterations'])
                for loss_n, loss_dict in loss.items():
                    info += ' ||'
                    info += '' if loss_n == 'none' else ' {}'.format(loss_n)
                    for k in loss_dict:
                        info += ' | {}: {:.4f}'.format(k, float(loss_dict[k]))

                itr_time_avg = (time.time() - epoch_start) / (itr + 1)
                info += ' || data_time: {dt}s | fbward_time: {fbt}s | iter_time: {it}s | epoch_time: {et} | left_time: {lt}'.format(
                        dt=round(data_time, 1),
                        fbt=round(time.time() - step_start, 1),
                        it=round(time.time() - itr_start, 1),
                        et=format_seconds(time.time() - epoch_start),
                        # remaining time is measured against the validation
                        # iteration count (the training count was used before)
                        lt=format_seconds(itr_time_avg*(self.dataloader['validation_iterations']-itr-1))
                        )

                self.logger.log_info(info)
            itr_start = time.time()
        # modify here to make sure dataloader['validation_iterations'] is correct
        assert itr >= 0, "The data is too less to form one iteration!"
        self.dataloader['validation_iterations'] = itr + 1

        if self.logger is not None:
            info = '{}: val'.format(self.args.name)
            for loss_n, loss_dict in overall_loss.items():
                info += '' if loss_n == 'none' else ' {}'.format(loss_n)
                info += ': Epoch {}/{}'.format(self.last_epoch, self.max_epochs)
                for k in loss_dict:
                    info += ' | {}: {:.4f}'.format(k, float(loss_dict[k]))
                    self.logger.add_scalar(tag='val/{}/{}'.format(loss_n, k), scalar_value=float(loss_dict[k]), global_step=self.last_epoch)
            self.logger.log_info(info)

    def validate(self):
        """Run validation via ``validate_epoch`` (subject to its schedule check)."""
        # the method is named validate_epoch; validation_epoch does not exist
        self.validate_epoch()

    def train(self):
        start_epoch = self.last_epoch + 1
        self.start_train_time = time.time()
        self.logger.log_info('{}: global rank {}: start training...'.format(self.args.name, self.args.global_rank), check_primary=False)
        for epoch in range(start_epoch, self.max_epochs):
            self.train_epoch()
            self.save(force=True)
            self.validate_epoch()
            

